# import packages
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm
import statsmodels.formula.api as smf

# split taxpayer_ids into: selected only by the first model, only by the second, or by both
def get_taxpayer_id_list(df1, df2):
    """Partition the taxpayer_ids flagged (aud_ind == 1) by two model outputs.

    Returns a 3-tuple of sets: (ids flagged only in df1, ids flagged only in
    df2, ids flagged in both).
    """
    def _flagged(frame):
        # taxpayer_ids on rows the model marked for audit
        return set(frame.taxpayer_id[frame.aud_ind == 1])

    first = _flagged(df1)
    second = _flagged(df2)
    return first - second, second - first, first & second
  
## function to get weighted mean
def w_avg(df, var, weight):
    """Return the mean of column `var` in `df`, weighted by column `weight`.

    Rows where either the value or its weight is missing are dropped from
    BOTH the numerator and the denominator. An empty (or all-missing)
    selection yields NaN.
    """
    d = df[var]
    w = df[weight]
    # pandas' sum() skips NaN, so without this mask a missing value would be
    # dropped from the numerator while its weight still counted in the
    # denominator -- silently treating missing values as 0.
    keep = d.notna() & w.notna()
    d = d[keep]
    w = w[keep]
    return (d * w).sum() / w.sum()

## linear estimator
def regEstimate(dataset, pbvarb, outcome, wvarb):
    """Fit a weighted least-squares regression of `outcome` on `pbvarb`.

    Weights come from column `wvarb`; standard errors use the HC1
    heteroskedasticity-robust covariance.

    Returns (coefficient on pbvarb, its standard error).
    """
    formula = outcome + '~' + pbvarb
    fitted = smf.wls(formula, dataset, weights=dataset[wvarb]).fit(cov_type='HC1')
    return fitted.params[pbvarb], fitted.bse[pbvarb]

## probabilistic estimator
def chenEstimate(dataset, pbvarb, outcome, wvarb):
    """Return the weighted outcome rate for black minus non-black taxpayers.

    Group membership is soft: each row counts toward the black group with
    share `pbvarb` and toward the non-black group with share `1 - pbvarb`,
    on top of its `wvarb` weight. (pbvarb presumably holds a probability in
    [0, 1] -- confirm against the data.)
    """
    p = dataset[pbvarb]
    y = dataset[outcome]
    w = dataset[wvarb]
    q = 1 - p
    black_rate = (p * y * w).sum() / (p * w).sum()
    nonblack_rate = (q * y * w).sum() / (q * w).sum()
    return black_rate - nonblack_rate

## weighted variance
def getWVar(values, weights):
    """Return the weighted (population, no degrees-of-freedom correction) variance of `values`."""
    center = np.average(values, weights=weights)
    deviations = values - center
    return np.average(deviations**2, weights=weights)

## standard error multiplier to obtain probabilistic standard error
def getSEMultiplier(dataset, pbvarb, wvarb):
    """Return the factor that scales the linear SE into the probabilistic SE.

    Computed as sqrt(weighted variance of pbvarb /
    (mean(pbvarb * w) * mean((1 - pbvarb) * w))).
    NOTE(review): multiplying all weights by c leaves the variance unchanged
    but scales the denominator by c**2 -- confirm this is intended.
    """
    p = dataset[pbvarb]
    w = dataset[wvarb]
    spread = getWVar(p, w)
    scale = (p * w).mean() * ((1 - p) * w).mean()
    return np.sqrt(spread / scale)

## get standard errors for each model.
def getSEs(dataset, pbvarb, outcome, wvarb):
    """Return (probabilistic SE, linear SE) for the disparity in `outcome`.

    The linear SE is the HC1 SE of the WLS coefficient on pbvarb; the
    probabilistic SE is that value times getSEMultiplier().
    """
    multiplier = getSEMultiplier(dataset, pbvarb, wvarb)
    _, linear_se = regEstimate(dataset, pbvarb, outcome, wvarb)
    return linear_se * multiplier, linear_se

# build the outcome table (one column per model selection)
def get_outcome_table(research_audits, df1, df2, outcome_list, overlap=False, *, include_diagnostics=False):
    """Summarize outcomes and disparity estimates for two models' selections.

    Parameters
    ----------
    research_audits : DataFrame
        Population to evaluate. Must contain `taxpayer_id`, `base_weight`,
        `predicted_prob_black` and every column named in `outcome_list`.
    df1, df2 : DataFrame
        Model outputs; a taxpayer is "selected" by a model when its row has
        `aud_ind == 1`.
    outcome_list : list of str
        Outcome columns whose base_weight-weighted mean is reported among
        each model's selected taxpayers.
    overlap : bool
        False -> columns 'df1' and 'df2' (everything each model selected).
        True  -> columns 'df1_only', 'df2_only' and 'union'; despite its name,
        'union' holds the taxpayers selected by BOTH models.
    include_diagnostics : bool, keyword-only
        When True, also append 'Proportion of taxpayer_ids', 'Share Selected',
        'Probabilistic SE' and 'Linear SE' rows. The default (False) produces
        the original table.

    Returns
    -------
    DataFrame with an 'Outcome' column followed by one column per selection:
    one row per entry of `outcome_list`, then 'Probabilistic Disparity' and
    'Linear Disparity' (plus the diagnostic rows when requested).
    """
    # output column name -> set of taxpayer_ids treated as "audited" for it
    if overlap:
        df1_only, df2_only, both = get_taxpayer_id_list(df1, df2)
        # column name kept as 'union' for compatibility; it is the intersection
        selections = {'df1_only': df1_only, 'df2_only': df2_only, 'union': both}
    else:
        selections = {
            'df1': set(df1.taxpayer_id[df1.aud_ind == 1]),
            'df2': set(df2.taxpayer_id[df2.aud_ind == 1]),
        }

    # weighted mean of each outcome among each selection (filter once per selection)
    means = []
    for ids in selections.values():
        selected = research_audits[research_audits.taxpayer_id.isin(ids)]
        means.append([w_avg(selected, col, 'base_weight') for col in outcome_list])
    outcome_table = pd.DataFrame(list(zip(outcome_list, *means)),
                                 columns=['Outcome', *selections])

    # disparity estimates - audit = 1 if selected by model, 0 otherwise
    lin_est = ['Linear Disparity']
    prob_est = ['Probabilistic Disparity']
    lin_se = ['Linear SE']
    prob_se = ['Probabilistic SE']
    for ids in selections.values():
        temp = research_audits.copy()
        temp['aud_ind'] = np.where(temp.taxpayer_id.isin(ids), 1, 0)
        # a single WLS fit supplies both the linear estimate and its SE
        lin, lin_se_value = regEstimate(temp, 'predicted_prob_black', 'aud_ind', 'base_weight')
        lin_est.append(lin)
        prob_est.append(chenEstimate(temp, 'predicted_prob_black', 'aud_ind', 'base_weight'))
        if include_diagnostics:
            lin_se.append(lin_se_value)
            prob_se.append(lin_se_value * getSEMultiplier(temp, 'predicted_prob_black', 'base_weight'))

    rows = [prob_est, lin_est]
    if include_diagnostics:
        n_total = len(research_audits)
        w_total = research_audits['base_weight'].sum()
        n_row = ['Proportion of taxpayer_ids'] + [len(ids) / n_total for ids in selections.values()]
        n_wgt = ['Share Selected'] + [
            research_audits.loc[research_audits.taxpayer_id.isin(ids), 'base_weight'].sum() / w_total
            for ids in selections.values()
        ]
        rows += [n_row, n_wgt, prob_se, lin_se]

    # append the summary rows beneath the outcome means
    for row in rows:
        outcome_table.loc[len(outcome_table.index)] = row
    return outcome_table

# Load the main RF research-audit dataset restricted to activity codes 270 and
# 271, left-join the Schedule C variables on (taxpayer_id_new, study_year),
# and lower-case every column name after the merge.
full_research_audits_df = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')
full_research_audits_df = full_research_audits_df[full_research_audits_df.activity_code.isin([270, 271])]

sch_c = pd.read_csv('/REDACTED/data/raw/research_audits_sch_c_varbs.csv').rename(
    columns={'taxpayer_id': 'taxpayer_id_new', 'study_year_c': 'study_year'}
)
full_research_audits_df = (
    full_research_audits_df
    .merge(sch_c, on=['taxpayer_id_new', 'study_year'], how='left')
    .rename(columns=str.lower)
)

# Option (disabled): treat Schedule C amounts left missing by the left merge as 0.
# cols = ['net_income_amt', 'net_income_amt_cor', 'gross_incm_amt', 'gross_incm_amt_cor', 'tot_expns_amt', 'tot_expns_amt_cor']
# full_research_audits_df[cols] = full_research_audits_df[cols].fillna(value=0)

# Selected-taxpayer files written during the model runs:
#   df1 = refundable-credit oracle,     df2 = refundable-credit regressor,
#   df3 = total-underreporting oracle,  df4 = total-underreporting regressor
df1, df2, df3, df4 = (
    pd.read_csv(path)
    for path in (
        '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids_alt_outcome.csv',
        '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids_alt_outcome.csv',
        '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids.csv',
        '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids.csv',
    )
)

# Outcome columns summarized in the table (assign evaluates these in order,
# so net_income_adj_c_ind can use net_income_adj_c).
full_research_audits_df = full_research_audits_df.assign(
    chg_in_tax_owed_ind_0=lambda d: np.where(d.chg_in_tax_owed_pv > 0, 1, 0),
    ref_cred_ind_0=lambda d: np.where(d.ref_cred_amt_dif_pv > 0, 1, 0),
    net_income_adj_c=lambda d: d.net_income_amt_cor - d.net_income_amt,
    # 0 * adjustment keeps NaN (no matched Schedule C) as NaN instead of coding it 0
    net_income_adj_c_ind=lambda d: np.where(d.net_income_adj_c > 0, 1, 0 * d.net_income_adj_c),
)

outcome_list = ['chg_in_tax_owed_ind_0', 'chg_in_tax_owed_pv', 'ref_cred_ind_0', 'ref_cred_amt_dif_pv', 'dep_reduced_ind', 'changed_from_hoh', 'net_income_adj_c_ind']

######## Shared formatting for the oracle and regressor tables ############

## display labels for the rows of outcome_list (same order as outcome_list)
outcome_str_desc_list = ['Change Rate', 'Mean Underreporting ($)', 'Refundable Credit Change Rate', 'Mean Refundable Credit Overclaiming ($)', 'Share with Dependent(s) Reduced', 'Adjustment from HOH', 'Business Income Change Rate']
## rows reported in whole dollars
integer_rows = ['Mean Underreporting ($)', 'Mean Refundable Credit Overclaiming ($)']
## rows reported as percentages rounded to the nearest tenth
one_digit_columns = ['Change Rate', 'Refundable Credit Change Rate', 'Share with Dependent(s) Reduced', 'Adjustment from HOH', 'Business Income Change Rate', 'Probabilistic Disparity', 'Linear Disparity']


def format_outcome_table(raw_table, model_str_desc_list):
    """Turn a get_outcome_table() result into display strings.

    Replaces the outcome variable names with the labels in
    outcome_str_desc_list, renames the model columns to
    `model_str_desc_list`, rounds the dollar rows to whole numbers and
    renders the rate/disparity rows as percentages with one decimal.

    Returns a DataFrame with an 'Outcome' column followed by one string
    column per model.
    """
    table = raw_table.round(5)
    table.iloc[0:len(outcome_str_desc_list), 0] = outcome_str_desc_list
    # build a new Index rather than writing into columns.values in place,
    # which pandas does not support (Index objects are immutable)
    table.columns = [table.columns[0], *model_str_desc_list]
    # transpose so each displayed row (outcome) becomes a column that can be
    # rounded/formatted as a whole
    table_t = table.set_index('Outcome').transpose()
    for row in integer_rows:
        table_t[row] = table_t[row].round().astype(int)
    for row in one_digit_columns:
        table_t[row] = (table_t[row] * 100).apply('{0:.1f}'.format)
    return table_t.astype(str).transpose().reset_index()


######## FIRST HALF, oracle and refundable oracle columns ############
outcome_table = format_outcome_table(
    get_outcome_table(full_research_audits_df, df3, df1, outcome_list, overlap=False),
    ['Total Underreporting Oracle', 'Refundable Credit Oracle'],
)

######## SECOND HALF, regressor and refundable regressor columns ############
outcome_table1 = format_outcome_table(
    get_outcome_table(full_research_audits_df, df4, df2, outcome_list, overlap=False),
    ['Total Underreporting Regressor', 'Refundable Credit Regressor'],
)

# merge together the two tables to get all 4 columns of results
outcome_table_final = outcome_table.merge(outcome_table1, how='left', on='Outcome')




############## output table as pdf #################
# (plt is already imported at the top of the file)
from matplotlib.backends.backend_pdf import PdfPages

fig, ax = plt.subplots(figsize=(12,4))
ax.axis('tight')
ax.axis('off')
ax.set_title('Model Selections')
table = ax.table(cellText=outcome_table_final.values, colLabels=outcome_table_final.columns, loc='center')

table.auto_set_font_size(False)
table.set_fontsize(8)
table.auto_set_column_width(col=list(range(len(outcome_table_final.columns))))

# the context manager finalizes and closes the PDF even if savefig raises
with PdfPages('/REDACTED/AC_outcome_table_4cols_update.pdf') as pp:
    pp.savefig(fig, bbox_inches='tight')